package requestmgr

import (
	"context"
	"testing"

	imageMock "github.com/stackrox/rox/central/image/datastore/mocks"
	imageV2Mocks "github.com/stackrox/rox/central/imagev2/datastore/mocks"
	vulnReqCacheMocks "github.com/stackrox/rox/central/vulnmgmt/vulnerabilityrequest/cache/mocks"
	"github.com/stackrox/rox/central/vulnmgmt/vulnerabilityrequest/common"
	vulnReqMocks "github.com/stackrox/rox/central/vulnmgmt/vulnerabilityrequest/datastore/mocks"
	"github.com/stackrox/rox/generated/storage"
	"github.com/stackrox/rox/pkg/features"
	"github.com/stackrox/rox/pkg/fixtures"
	"github.com/stretchr/testify/suite"
	"go.uber.org/mock/gomock"
)

func TestVulnRequestManagerCacheOpsWithUnifiedDeferral(t *testing.T) {
	suite.Run(t, new(VulnRequestManagerCacheTestSuite))
}

// VulnRequestManagerCacheTestSuite exercises the request manager against
// mocked datastores and mocked pending/active request caches, asserting which
// cache operations each manager call triggers.
type VulnRequestManagerCacheTestSuite struct {
	suite.Suite

	// ctx holds a background context; no test in this file reads it (they
	// pass the package-level allAccessCtx instead).
	ctx context.Context
	// mockCtrl is the gomock controller every mock below is created from.
	mockCtrl *gomock.Controller

	// manager is the Manager under test, wired to the mocks below.
	manager Manager

	// Datastore mocks handed to the manager.
	vulnReqDataStore *vulnReqMocks.MockDataStore
	imageDataStore   *imageMock.MockDataStore
	imageV2DataStore *imageV2Mocks.MockDataStore

	// Cache mocks for pending and active vulnerability requests.
	pendingReqCache *vulnReqCacheMocks.MockVulnReqCache
	activeReqCache  *vulnReqCacheMocks.MockVulnReqCache
}

// SetupSuite initializes state shared by every test in the suite.
//
// The gomock controller, the mocks and the manager under test are deliberately
// built per test in SetupTest rather than here. In SetupSuite, s.T() is the
// suite-level *testing.T. A controller bound to it would report unexpected
// mock calls by calling Fatalf on the parent test from inside a subtest's
// goroutine. The testing package flags that as "subtest may have called
// FailNow on a parent test" instead of cleanly failing the offending test.
// Sharing one controller would also let an expectation left over by one test
// method be satisfied by calls made in another.
func (s *VulnRequestManagerCacheTestSuite) SetupSuite() {
	s.ctx = context.Background()
}

// SetupTest creates a fresh gomock controller, fresh mocks and a fresh manager
// before every test method. Here s.T() is the per-test *testing.T, and
// gomock.NewController registers the controller's Finish via t.Cleanup, so each
// test's expectations are verified (and unexpected calls attributed) to the
// test that declared them.
func (s *VulnRequestManagerCacheTestSuite) SetupTest() {
	s.mockCtrl = gomock.NewController(s.T())

	s.vulnReqDataStore = vulnReqMocks.NewMockDataStore(s.mockCtrl)
	s.imageDataStore = imageMock.NewMockDataStore(s.mockCtrl)
	s.imageV2DataStore = imageV2Mocks.NewMockDataStore(s.mockCtrl)
	s.pendingReqCache = vulnReqCacheMocks.NewMockVulnReqCache(s.mockCtrl)
	s.activeReqCache = vulnReqCacheMocks.NewMockVulnReqCache(s.mockCtrl)

	// The remaining dependencies are passed as nil; presumably the code paths
	// exercised by this suite never reach them (confirm against New).
	s.manager = New(nil, s.vulnReqDataStore, s.pendingReqCache, s.activeReqCache, s.imageDataStore, s.imageV2DataStore, nil, nil, nil, nil)
}

// TestAdd creates three deferral requests for the same CVE, from the narrowest
// scope (one image tag) to the broadest (global), and expects every creation to
// succeed and to add the request to the pending cache. After each creation the
// request is marked APPROVED and returned by the next SearchRawRequests call.
// Each broader request is therefore created while the narrower ones are
// "enforced", and it must still be accepted.
func (s *VulnRequestManagerCacheTestSuite) TestAdd() {
	global := fixtures.GetGlobalDeferralRequestV2("cve1")
	allTags := fixtures.GetImageScopeDeferralRequest("reg", "remote", ".*", "cve1")
	singleTag := fixtures.GetImageScopeDeferralRequest("reg", "remote", "tag", "cve1")

	var active []*storage.VulnerabilityRequest
	for _, r := range []*storage.VulnerabilityRequest{singleTag, allTags, global} {
		// Clear the fixture-assigned ID before handing the request to Create.
		r.Id = ""

		s.vulnReqDataStore.EXPECT().SearchRawRequests(gomock.Any(), gomock.Any()).Return(active, nil)
		s.vulnReqDataStore.EXPECT().AddRequest(gomock.Any(), gomock.Any()).Return(nil)
		s.pendingReqCache.EXPECT().Add(r)

		err := s.manager.Create(allAccessCtx, r)
		s.NoError(err)

		// Treat this request as in force for the remaining iterations.
		r.Status = storage.RequestStatus_APPROVED
		active = append(active, r)
	}
}

// TestAddConflicts creates deferral requests for the same CVE from the broadest
// scope (global) to the narrowest (one image tag). Only the first creation may
// succeed. Once a broader request is returned by SearchRawRequests as APPROVED,
// creating a narrower one must fail without persisting or caching it.
func (s *VulnRequestManagerCacheTestSuite) TestAddConflicts() {
	global := fixtures.GetGlobalDeferralRequestV2("cve1")
	allTags := fixtures.GetImageScopeDeferralRequest("reg", "remote", ".*", "cve1")
	singleTag := fixtures.GetImageScopeDeferralRequest("reg", "remote", "tag", "cve1")

	var active []*storage.VulnerabilityRequest
	for i, r := range []*storage.VulnerabilityRequest{global, allTags, singleTag} {
		r.Id = ""

		s.vulnReqDataStore.EXPECT().SearchRawRequests(gomock.Any(), gomock.Any()).Return(active, nil)

		switch i {
		case 0:
			// Nothing is enforced yet: creation goes through.
			s.vulnReqDataStore.EXPECT().AddRequest(gomock.Any(), gomock.Any()).Return(nil)
			s.pendingReqCache.EXPECT().Add(r)
			s.NoError(s.manager.Create(allAccessCtx, r))
		default:
			// No AddRequest or cache expectation is registered, so any attempt
			// to persist or cache this request is an unexpected mock call.
			s.Error(s.manager.Create(allAccessCtx, r))
		}

		r.Status = storage.RequestStatus_APPROVED
		active = append(active, r)
	}
}

// TestApprove covers two approval paths for an image-scoped deferral:
//   - Unexpired request: the request moves from the pending cache to the active
//     cache. An image search is expected on the v2 image datastore when
//     FlattenImageData is enabled, otherwise on the legacy one.
//   - Expired request (snooze workflow): the status update returns a request
//     flagged Expired, and Approve returns an error and a nil request. No cache
//     expectations are registered.
func (s *VulnRequestManagerCacheTestSuite) TestApprove() {
	expected := fixtures.GetImageScopeDeferralRequest("r", "r", "g", "cve")
	id := expected.GetId()
	// Build a fresh params value for every call.
	approveParams := func() *common.VulnRequestParams {
		return &common.VulnRequestParams{Comment: "approved"}
	}

	// Unexpired request.
	s.vulnReqDataStore.EXPECT().
		UpdateRequestStatus(allAccessCtx, id, "approved", storage.RequestStatus_APPROVED).
		Return(expected, nil)
	s.vulnReqDataStore.EXPECT().SearchRawRequests(gomock.Any(), gomock.Any()).Return(nil, nil)
	expected.Status = storage.RequestStatus_APPROVED
	s.pendingReqCache.EXPECT().Remove(id)
	s.activeReqCache.EXPECT().Add(expected)
	if features.FlattenImageData.Enabled() {
		s.imageV2DataStore.EXPECT().Search(allAccessCtx, gomock.Any()).Return(nil, nil)
	} else {
		s.imageDataStore.EXPECT().Search(allAccessCtx, gomock.Any()).Return(nil, nil)
	}

	approved, err := s.manager.Approve(allAccessCtx, id, approveParams())
	s.NoError(err)
	s.Equal(id, approved.GetId())

	// Expired request: must not touch the caches.
	expected.Expired = true
	s.vulnReqDataStore.EXPECT().
		UpdateRequestStatus(allAccessCtx, id, "approved", storage.RequestStatus_APPROVED).
		Return(expected, nil)
	s.vulnReqDataStore.EXPECT().SearchRawRequests(gomock.Any(), gomock.Any()).Return(nil, nil)

	approved, err = s.manager.Approve(allAccessCtx, id, approveParams())
	s.Error(err)
	s.Nil(approved)
}

// TestDeny checks that denying a request sets its status to DENIED with the
// supplied comment and evicts it from the pending cache. No expectation is
// registered on the active cache, so it must be left untouched.
func (s *VulnRequestManagerCacheTestSuite) TestDeny() {
	denied := fixtures.GetImageScopeDeferralRequest("r", "r", "g", "cve")
	reqID := denied.GetId()

	s.vulnReqDataStore.EXPECT().
		UpdateRequestStatus(allAccessCtx, reqID, "denied", storage.RequestStatus_DENIED).
		Return(denied, nil)
	denied.Status = storage.RequestStatus_DENIED
	s.pendingReqCache.EXPECT().Remove(reqID)

	got, err := s.manager.Deny(allAccessCtx, reqID, &common.VulnRequestParams{Comment: "denied"})
	s.NoError(err)
	s.Equal(reqID, got.GetId())
}

// TestDelete checks that deleting a request removes it from the datastore and
// from the pending cache.
func (s *VulnRequestManagerCacheTestSuite) TestDelete() {
	reqID := fixtures.GetImageScopeDeferralRequest("r", "r", "g", "cve").GetId()

	s.vulnReqDataStore.EXPECT().RemoveRequest(allAccessCtx, reqID).Return(nil)
	s.pendingReqCache.EXPECT().Remove(reqID)

	err := s.manager.Delete(allAccessCtx, reqID)
	s.NoError(err)
}

// TestUndo checks that undoing a request marks it inactive and evicts it from
// both the pending and the active cache. An image search is expected on the v2
// image datastore when FlattenImageData is enabled, otherwise on the legacy one.
func (s *VulnRequestManagerCacheTestSuite) TestUndo() {
	undone := fixtures.GetImageScopeDeferralRequest("r", "r", "g", "cve")
	reqID := undone.GetId()

	s.vulnReqDataStore.EXPECT().MarkRequestInactive(allAccessCtx, reqID, gomock.Any()).Return(undone, nil)

	// Eviction from both caches.
	s.pendingReqCache.EXPECT().Remove(reqID)
	s.activeReqCache.EXPECT().Remove(reqID)

	// Image lookup goes through whichever image datastore the feature flag selects.
	if features.FlattenImageData.Enabled() {
		s.imageV2DataStore.EXPECT().Search(allAccessCtx, gomock.Any()).Return(nil, nil)
	} else {
		s.imageDataStore.EXPECT().Search(allAccessCtx, gomock.Any()).Return(nil, nil)
	}

	got, err := s.manager.Undo(allAccessCtx, reqID, &common.VulnRequestParams{})
	s.NoError(err)
	s.Equal(reqID, got.GetId())
}

// TestUpdateExpiry checks that UpdateExpiry forwards the comment and expiry to
// the datastore's UpdateRequestExpiry. The request it returns must then be
// added to the pending cache.
func (s *VulnRequestManagerCacheTestSuite) TestUpdateExpiry() {
	updated := fixtures.GetImageScopeDeferralRequest("r", "r", "g", "cve")
	reqID := updated.GetId()

	s.vulnReqDataStore.EXPECT().
		UpdateRequestExpiry(allAccessCtx, reqID, "update", &storage.RequestExpiry{}).
		Return(updated, nil)
	s.pendingReqCache.EXPECT().Add(updated)

	params := &common.VulnRequestParams{
		Comment: "update",
		Expiry:  &storage.RequestExpiry{},
	}
	got, err := s.manager.UpdateExpiry(allAccessCtx, reqID, params)
	s.NoError(err)
	s.Equal(reqID, got.GetId())
}
